{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "2f6e3d91",
   "metadata": {},
   "source": [
    "# 2-1,张量数据结构\n",
    "\n",
    "Pytorch的基本数据结构是张量Tensor。张量即多维数组。Pytorch的张量和numpy中的array很类似。\n",
    "\n",
    "本节我们主要介绍张量的数据类型、张量的维度、张量的尺寸、张量和numpy数组等基本概念。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "30005fa6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.__version__=2.4.0\n"
     ]
    }
   ],
   "source": [
    "import torch \n",
    "print(\"torch.__version__=\"+torch.__version__) \n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "41f40abd",
   "metadata": {},
   "source": [
    "### 一，张量的数据类型"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "08b194dd",
   "metadata": {},
   "source": [
    "张量的数据类型和numpy.array基本一一对应，但是不支持str类型。\n",
    "\n",
    "包括:\n",
    "\n",
    "torch.float64(torch.double), \n",
    "\n",
    "**torch.float32(torch.float)**, \n",
    "\n",
    "torch.float16, \n",
    "\n",
    "torch.int64(torch.long), \n",
    "\n",
    "torch.int32(torch.int), \n",
    "\n",
    "torch.int16, \n",
    "\n",
    "torch.int8, \n",
    "\n",
    "torch.uint8, \n",
    "\n",
    "torch.bool\n",
    "\n",
    "一般神经网络建模使用的都是torch.float32类型。 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "08ee241e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(1) torch.int64\n",
      "tensor(2.) torch.float32\n",
      "tensor(True) torch.bool\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import torch \n",
    "\n",
    "# 自动推断数据类型\n",
    "\n",
    "i = torch.tensor(1);print(i,i.dtype)\n",
    "x = torch.tensor(2.0);print(x,x.dtype)\n",
    "b = torch.tensor(True);print(b,b.dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "639539f3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(1, dtype=torch.int32) torch.int32\n",
      "tensor(2., dtype=torch.float64) torch.float64\n"
     ]
    }
   ],
   "source": [
    "# 指定数据类型\n",
    "\n",
    "i = torch.tensor(1,dtype = torch.int32);print(i,i.dtype)\n",
    "x = torch.tensor(2.0,dtype = torch.double);print(x,x.dtype)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "b44d6e5d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([0], dtype=torch.int32) torch.int32\n",
      "tensor(2.) torch.float32\n",
      "tensor([ True, False,  True, False]) torch.bool\n"
     ]
    }
   ],
   "source": [
    "# 使用特定类型构造函数\n",
    "\n",
    "i = torch.IntTensor(1);print(i,i.dtype)\n",
    "x = torch.Tensor(np.array(2.0));print(x,x.dtype) #等价于torch.FloatTensor\n",
    "b = torch.BoolTensor(np.array([1,0,2,0])); print(b,b.dtype)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "58181469",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(1) torch.int64\n",
      "tensor(1.) torch.float32\n",
      "tensor(1.) torch.float32\n",
      "tensor(1.) torch.float32\n"
     ]
    }
   ],
   "source": [
    "# 不同类型进行转换\n",
    "\n",
    "i = torch.tensor(1); print(i,i.dtype)\n",
    "x = i.float(); print(x,x.dtype) #调用 float方法转换成浮点类型\n",
    "y = i.type(torch.float); print(y,y.dtype) #使用type函数转换成浮点类型\n",
    "z = i.type_as(x);print(z,z.dtype) #使用type_as方法转换成某个Tensor相同类型\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15bcb7cb",
   "metadata": {},
   "source": [
    "### 二，张量的维度"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3ddbc02f",
   "metadata": {},
   "source": [
    "不同类型的数据可以用不同维度(dimension)的张量来表示。\n",
    "\n",
    "标量为0维张量，向量为1维张量，矩阵为2维张量。\n",
    "\n",
    "彩色图像有rgb三个通道，可以表示为3维张量。\n",
    "\n",
    "视频还有时间维，可以表示为4维张量。\n",
    "\n",
    "可以简单地总结为：有几层中括号，就是多少维的张量。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c7185aec",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(True)\n",
      "0\n"
     ]
    }
   ],
   "source": [
    "scalar = torch.tensor(True)\n",
    "print(scalar)\n",
    "print(scalar.dim())  # 标量，0维张量\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "6d9668ac",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([1., 2., 3., 4.])\n",
      "1\n"
     ]
    }
   ],
   "source": [
    "vector = torch.tensor([1.0,2.0,3.0,4.0]) #向量，1维张量\n",
    "print(vector)\n",
    "print(vector.dim())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "2f1419cc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[1., 2.],\n",
      "        [3., 4.]])\n",
      "2\n"
     ]
    }
   ],
   "source": [
    "matrix = torch.tensor([[1.0,2.0],[3.0,4.0]]) #矩阵, 2维张量\n",
    "print(matrix)\n",
    "print(matrix.dim())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "912a056d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[1., 2.],\n",
      "         [3., 4.]],\n",
      "\n",
      "        [[5., 6.],\n",
      "         [7., 8.]]])\n",
      "3\n"
     ]
    }
   ],
   "source": [
    "tensor3 = torch.tensor([[[1.0,2.0],[3.0,4.0]],[[5.0,6.0],[7.0,8.0]]])  # 3维张量\n",
    "print(tensor3)\n",
    "print(tensor3.dim())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "657029ac",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[[1., 1.],\n",
      "          [2., 2.]],\n",
      "\n",
      "         [[3., 3.],\n",
      "          [4., 4.]]],\n",
      "\n",
      "\n",
      "        [[[5., 5.],\n",
      "          [6., 6.]],\n",
      "\n",
      "         [[7., 7.],\n",
      "          [8., 8.]]]])\n",
      "4\n"
     ]
    }
   ],
   "source": [
    "tensor4 = torch.tensor([[[[1.0,1.0],[2.0,2.0]],[[3.0,3.0],[4.0,4.0]]],\n",
    "                        [[[5.0,5.0],[6.0,6.0]],[[7.0,7.0],[8.0,8.0]]]])  # 4维张量\n",
    "print(tensor4)\n",
    "print(tensor4.dim())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b47b791",
   "metadata": {},
   "source": [
    "### 三，张量的尺寸"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4a3cd4b9",
   "metadata": {},
   "source": [
    "可以使用 shape属性或者 size()方法查看张量在每一维的长度。\n",
    "\n",
    "可以使用view方法改变张量的尺寸。\n",
    "\n",
    "如果view方法改变尺寸失败，可以使用reshape方法。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "3f2499fb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([])\n",
      "torch.Size([])\n"
     ]
    }
   ],
   "source": [
    "scalar = torch.tensor(True)\n",
    "print(scalar.size())\n",
    "print(scalar.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "af6cc014",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([4])\n",
      "torch.Size([4])\n"
     ]
    }
   ],
   "source": [
    "vector = torch.tensor([1.0,2.0,3.0,4.0])\n",
    "print(vector.size())\n",
    "print(vector.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "c1426615",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([2, 2])\n"
     ]
    }
   ],
   "source": [
    "matrix = torch.tensor([[1.0,2.0],[3.0,4.0]])\n",
    "print(matrix.size())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "6216e5d1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])\n",
      "torch.Size([12])\n",
      "tensor([[ 0,  1,  2,  3],\n",
      "        [ 4,  5,  6,  7],\n",
      "        [ 8,  9, 10, 11]])\n",
      "torch.Size([3, 4])\n",
      "tensor([[ 0,  1,  2],\n",
      "        [ 3,  4,  5],\n",
      "        [ 6,  7,  8],\n",
      "        [ 9, 10, 11]])\n",
      "torch.Size([4, 3])\n"
     ]
    }
   ],
   "source": [
    "# 使用view可以改变张量尺寸\n",
    "\n",
    "vector = torch.arange(0,12)\n",
    "print(vector)\n",
    "print(vector.shape)\n",
    "\n",
    "matrix34 = vector.view(3,4)\n",
    "print(matrix34)\n",
    "print(matrix34.shape)\n",
    "\n",
    "matrix43 = vector.view(4,-1) #-1表示该位置长度由程序自动推断\n",
    "print(matrix43)\n",
    "print(matrix43.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "9c5e90e3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 0,  1,  2,  3,  4,  5],\n",
      "        [ 6,  7,  8,  9, 10, 11]])\n",
      "torch.Size([2, 6])\n",
      "False\n",
      "tensor([[ 0,  6,  1,  7],\n",
      "        [ 2,  8,  3,  9],\n",
      "        [ 4, 10,  5, 11]])\n"
     ]
    }
   ],
   "source": [
    "# 有些操作会让张量存储结构扭曲，直接使用view会失败，可以用reshape方法\n",
    "\n",
    "matrix26 = torch.arange(0,12).view(2,6)\n",
    "print(matrix26)\n",
    "print(matrix26.shape)\n",
    "\n",
    "# 转置操作让张量存储结构扭曲\n",
    "matrix62 = matrix26.t()\n",
    "print(matrix62.is_contiguous())\n",
    "\n",
    "\n",
    "# 直接使用view方法会失败，可以使用reshape方法\n",
    "#matrix34 = matrix62.view(3,4) #error!\n",
    "matrix34 = matrix62.reshape(3,4) #等价于matrix34 = matrix62.contiguous().view(3,4)\n",
    "print(matrix34)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13519be6",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "9bd6edcf",
   "metadata": {},
   "source": [
    "### 四，张量和numpy数组"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eeae56d7",
   "metadata": {},
   "source": [
    "可以用numpy方法从Tensor得到numpy数组，也可以用torch.from_numpy从numpy数组得到Tensor。\n",
    "\n",
    "这两种方法关联的Tensor和numpy数组是共享数据内存的。\n",
    "\n",
    "如果改变其中一个，另外一个的值也会发生改变。\n",
    "\n",
    "如果有需要，可以用张量的clone方法拷贝张量，中断这种关联。\n",
    "\n",
    "此外，还可以使用item方法从标量张量得到对应的Python数值。\n",
    "\n",
    "使用tolist方法从张量得到对应的Python数值列表。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "8cea696f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "2c27e5ce",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "before add 1:\n",
      "[0. 0. 0.]\n",
      "tensor([0., 0., 0.], dtype=torch.float64)\n",
      "\n",
      "after add 1:\n",
      "[1. 1. 1.]\n",
      "tensor([1., 1., 1.], dtype=torch.float64)\n"
     ]
    }
   ],
   "source": [
    "#torch.from_numpy函数从numpy数组得到Tensor\n",
    "\n",
    "arr = np.zeros(3)\n",
    "tensor = torch.from_numpy(arr)\n",
    "print(\"before add 1:\")\n",
    "print(arr)\n",
    "print(tensor)\n",
    "\n",
    "print(\"\\nafter add 1:\")\n",
    "np.add(arr,1, out = arr) #给 arr增加1，tensor也随之改变\n",
    "print(arr)\n",
    "print(tensor)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "88fdbf45",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "before add 1:\n",
      "tensor([0., 0., 0.])\n",
      "[0. 0. 0.]\n",
      "\n",
      "after add 1:\n",
      "tensor([1., 1., 1.])\n",
      "[1. 1. 1.]\n"
     ]
    }
   ],
   "source": [
    "# numpy方法从Tensor得到numpy数组\n",
    "\n",
    "tensor = torch.zeros(3)\n",
    "arr = tensor.numpy()\n",
    "print(\"before add 1:\")\n",
    "print(tensor)\n",
    "print(arr)\n",
    "\n",
    "print(\"\\nafter add 1:\")\n",
    "\n",
    "#使用带下划线的方法表示计算结果会返回给调用 张量\n",
    "tensor.add_(1) #给 tensor增加1，arr也随之改变 \n",
    "#或： torch.add(tensor,1,out = tensor)\n",
    "print(tensor)\n",
    "print(arr)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "13fb0e0e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "before add 1:\n",
      "tensor([0., 0., 0.])\n",
      "[0. 0. 0.]\n",
      "\n",
      "after add 1:\n",
      "tensor([1., 1., 1.])\n",
      "[0. 0. 0.]\n"
     ]
    }
   ],
   "source": [
    "# 可以用clone() 方法拷贝张量，中断这种关联\n",
    "\n",
    "tensor = torch.zeros(3)\n",
    "\n",
    "#使用clone方法拷贝张量, 拷贝后的张量和原始张量内存独立\n",
    "arr = tensor.clone().numpy() # 也可以使用tensor.data.numpy()\n",
    "print(\"before add 1:\")\n",
    "print(tensor)\n",
    "print(arr)\n",
    "\n",
    "print(\"\\nafter add 1:\")\n",
    "\n",
    "#使用 带下划线的方法表示计算结果会返回给调用 张量\n",
    "tensor.add_(1) #给 tensor增加1，arr不再随之改变\n",
    "print(tensor)\n",
    "print(arr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "72f9b26a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0\n",
      "<class 'float'>\n",
      "[[0.24736273288726807, 0.7015318870544434], [0.6680377125740051, 0.03158283233642578]]\n",
      "<class 'list'>\n"
     ]
    }
   ],
   "source": [
    "# item方法和tolist方法可以将张量转换成Python数值和数值列表\n",
    "scalar = torch.tensor(1.0)\n",
    "s = scalar.item()\n",
    "print(s)\n",
    "print(type(s))\n",
    "\n",
    "tensor = torch.rand(2,2)\n",
    "t = tensor.tolist()\n",
    "print(t)\n",
    "print(type(t))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46a04e72",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "a10773d7",
   "metadata": {},
   "source": [
    "**如果本书对你有所帮助，想鼓励一下作者，记得给本项目加一颗星星star⭐️，并分享给你的朋友们喔😊!** \n",
    "\n",
    "如果对本书内容理解上有需要进一步和作者交流的地方，欢迎在公众号\"算法美食屋\"下留言。作者时间和精力有限，会酌情予以回复。\n",
    "\n",
    "也可以在公众号后台回复关键字：**加群**，加入读者交流群和大家讨论。\n",
    "\n",
    "![算法美食屋logo.png](https://tva1.sinaimg.cn/large/e6c9d24egy1h41m2zugguj20k00b9q46.jpg)\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "formats": "ipynb,md",
   "main_language": "python"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
